import torch
import wandb
import numpy as np
import pytorch_lightning as pl
from trainers.trainer_simple import LitModel as LitSimple, sample_images_at_mc_locs, huber
from model import zoo as models
from model import raymarcher_lebesgue as raymarchers
from model.raymarcher_lebesgue import ImplicitRendererDict

from pytorch3d.renderer import (
    NDCGridRaysampler,
    MonteCarloRaysampler
)


class LitModel(LitSimple):
    """Lightning module for NeRF models rendered with the Lebesgue raymarcher.

    Reuses helpers and hooks of ``trainers.trainer_simple.LitModel`` but builds
    its own model, raysamplers, raymarcher and ``ImplicitRendererDict``
    renderers:

    * ``renderer_mc`` uses a Monte-Carlo raysampler and drives ``training_step``.
    * ``renderer_grid`` uses an NDC grid raysampler producing full
      ``render_height`` x ``render_width`` images (presumably consumed by the
      parent's validation/rendering code).
    """

    def __init__(self, hparams):
        """Build the NeRF model, raysamplers, raymarcher and renderers.

        Args:
            hparams: Experiment configuration with ``model``, ``data_conf``,
                ``data`` and ``train`` sections. Both attribute access and
                ``.get`` are used on it (presumably an OmegaConf ``DictConfig``
                -- confirm against the launcher).

        Raises:
            ValueError: if ``hparams.model_name`` does not name a model in
                ``model.zoo``.
        """
        # Deliberately skip LitSimple.__init__ and run the next __init__ in the
        # MRO: this class constructs its own model and renderers below.
        super(LitSimple, self).__init__()
        # Mirror the sampler far bound into the model config; it is also passed
        # to the renderer as `max_depth` in training_step.
        hparams.model.density.max_depth = hparams.data_conf.max_depth
        self.save_hyperparameters(hparams)
        self.debug = hparams.get('debug', False)

        model_name = self.hparams.get('model_name', 'NeRFLebesgue')
        # Fail loudly on an unknown name instead of falling back to a
        # non-callable value.
        Nerf = getattr(models, model_name, None)
        if Nerf is None:
            raise ValueError(f"Unknown model_name {model_name!r}: not found in model.zoo")
        self.nerf = Nerf(self.hparams.model)

        min_depth = self.hparams.data_conf.get('min_depth', 0)
        max_depth = self.hparams.data_conf.max_depth

        # 1) Instantiate the raysamplers.
        # NDCGridRaysampler generates a rectangular image grid of rays whose
        # coordinates follow the PyTorch3D coordinate conventions.
        # Here, depth is responsible for the value of the CDF.
        self.raysampler_grid = NDCGridRaysampler(
            image_height=self.hparams.data_conf.render_height,
            image_width=self.hparams.data_conf.render_width,
            n_pts_per_ray=self.hparams.data.n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        # MonteCarloRaysampler generates a random subset of
        # `n_rays_per_image` rays emitted from the image plane.
        # Here, depth is responsible for the value of the CDF.
        self.raysampler_mc = MonteCarloRaysampler(
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            n_rays_per_image=self.hparams.train.n_rays_per_image,
            n_pts_per_ray=self.hparams.train.n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )
        # 2) The raymarcher class is selected by name from model.raymarcher_lebesgue.
        self.raymarcher = getattr(raymarchers, self.hparams.model.get('raymarcher', 'LebesgueRaymarcher'))()

        # 3) Instantiate the implicit renderers for both raysamplers.
        # `stratified_resamling` (sic) is kept as spelled: it must match the
        # keyword accepted by ImplicitRendererDict.
        stratified = self.hparams.train.get('stratified_sampling', False)
        self.renderer_grid = ImplicitRendererDict(
            raysampler=self.raysampler_grid,
            raymarcher=self.raymarcher,
            stratified_resamling=stratified,
        )
        self.renderer_mc = ImplicitRendererDict(
            raysampler=self.raysampler_mc,
            raymarcher=self.raymarcher,
            stratified_resamling=stratified,
        )

    def on_train_epoch_start(self):
        """Apply the color-sample schedule, then run the per-step start hook.

        ``color_epochs_pretrain`` (read from the model's color config) is the
        number of pretraining epochs; a negative value (the default, -1)
        disables the schedule. Once that many epochs have elapsed, the model's
        ``n_color_samples`` is switched to ``n_color_samples_after_pretrain``.
        """
        super().on_train_epoch_start()
        pretrain_epochs = self.nerf.hparams.color.get('color_epochs_pretrain', -1)
        if 0 <= pretrain_epochs <= self.current_epoch:
            # NOTE(review): the post-pretrain value is read from the top-level
            # `color` section, not the model's; confirm this matches the config.
            self.nerf.hparams.color.n_color_samples = self.hparams.color.n_color_samples_after_pretrain
        self.on_train_step_start()

    def training_step(self, batch, batch_idx):
        """Render a Monte-Carlo batch of rays and return the training loss.

        The loss combines three terms:

        * the mean Huber color error;
        * ``silhouette_weight`` times the mean Huber silhouette error, added
          only when ``batch`` has ``target_silhouettes`` and the weight is
          positive;
        * ``self.l1_l2_reg()``.
        """
        # Evaluate the nerf model.
        rendered_images_silhouettes, sampled_rays, rays_features_dict = self.renderer_mc(
            cameras=batch['cameras'],
            volumetric_function=self.nerf,
            nerf_opacity=self.nerf,
            max_depth=self.hparams.model.density.max_depth,
        )
        silhouette_weight = self.hparams.train.get('silhouette_weight', 1)
        if 'target_silhouettes' in batch and silhouette_weight > 0:
            rendered_images, rendered_silhouettes = rendered_images_silhouettes.split([3, 1], dim=-1)
            # Silhouette error: mean Huber loss between the predicted masks and
            # the target silhouettes sampled at the ray locations.
            silhouettes_at_rays = sample_images_at_mc_locs(
                batch['target_silhouettes'][..., None],
                sampled_rays.xys,
            )
            sil_err = huber(rendered_silhouettes, silhouettes_at_rays).abs().mean()
        else:
            sil_err = None
            # Drop the silhouette channel if the renderer produced one.
            if rendered_images_silhouettes.shape[-1] == 4:
                rendered_images, _ = rendered_images_silhouettes.split([3, 1], dim=-1)
            else:
                rendered_images = rendered_images_silhouettes

        # Color error: mean Huber loss between the rendered colors and the
        # target images sampled at the ray locations.
        colors_at_rays = sample_images_at_mc_locs(batch['target_images'], sampled_rays.xys)
        if self.hparams.train.get('add_noise', False):
            # Uniform noise in [0, 1/255), i.e. below one 8-bit step
            # (assuming colors are in [0, 1]).
            colors_at_rays = colors_at_rays + torch.rand_like(colors_at_rays) / 255.0
        color_err = huber(rendered_images, colors_at_rays).abs().mean()

        # Out-of-place additions so `loss` never mutates the `color_err` tensor
        # it starts out aliasing.
        loss = color_err
        self.log('train/color_err', color_err.item())
        if sil_err is not None:
            loss = loss + silhouette_weight * sil_err
            self.log('train/sil_err', sil_err.item())
        loss = loss + self.l1_l2_reg()
        self.log('train/loss_total', loss.item())

        if self.hparams.train.get('log_t_values', False) and batch_idx % 10 == 0:
            self._log_t_value_histograms(rays_features_dict)
        return loss

    def _log_t_value_histograms(self, rays_features_dict):
        """Log 100-bin wandb histograms of selected renderer ray features.

        Requires ``self.logger.experiment`` to accept ``wandb.Histogram``
        values (i.e. a wandb logger).
        """
        # wandb key -> key in rays_features_dict.
        # NOTE(review): 'train/hist_y_max' is fed from 'opacity'; confirm that
        # is intended rather than a dedicated 'y_max' feature.
        hist_sources = {
            'train/hist_t_grid': 't_grid',
            'train/hist_y_grid': 'grid',
            'train/hist_y_min': 'y_min',
            'train/hist_y_max': 'opacity',
            'train/color': 'avg_values',
        }
        with torch.no_grad():
            histograms = {
                log_key: wandb.Histogram(np_histogram=np.histogram(
                    rays_features_dict[feature_key].detach().reshape(-1).cpu().numpy(),
                    bins=100,
                ))
                for log_key, feature_key in hist_sources.items()
            }
        self.logger.experiment.log(histograms, step=self.global_step)

    def validation_step(self, batch, batch_idx):
        """Run the parent validation step, optionally with a different color-sample count.

        If ``hparams.data.n_color_samples`` is set, it temporarily replaces
        the model's ``n_color_samples``. The training value is restored
        afterwards, even if validation raises.
        """
        val_color_samples = self.hparams.data.get('n_color_samples', None)
        if val_color_samples is None:
            return super().validation_step(batch, batch_idx)

        prev_n_color_samples = self.nerf.hparams.color.n_color_samples
        self.nerf.hparams.color.n_color_samples = val_color_samples  # validation color samples
        try:
            return super().validation_step(batch, batch_idx)
        finally:
            # Restore the training number of color samples.
            self.nerf.hparams.color.n_color_samples = prev_n_color_samples
